import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib import rcParams
from upsetplot import from_memberships, plot
from matplotlib import pyplot
from itertools import islice
import seaborn as sns
import json
import csv
import ast
from evaluate import load
import spacy
import inflect
import re
import time

# spaCy English pipeline ('en_core_web_sm'), loaded once at import time;
# used by find_subject() and processing().
nlp = spacy.load('en_core_web_sm')

# inflect engine: plural() is used by filter() and filter_exceptions(),
# plural_verb() by processing().
infl = inflect.engine()

# Model families covered by the experiments.
LMs_families = ['bloom', 'gpt2', 'xlnet', 'bart', 'llama2']

# Full model identifiers: two sizes per family, in LMs_families order.
LMs = ['bigscience/bloom-560m', 'bigscience/bloom-3b',
       'gpt2', 'gpt2-medium',
       'xlnet-base-cased', 'xlnet-large-cased',
       'facebook/bart-base', 'facebook/bart-large',
       'meta-llama/Llama-2-7b-hf', 'meta-llama/Llama-2-13b-hf']

# Model identifier -> short column name (as listed in LMs_columns).
# Keys must match the entries of LMs exactly. Identifiers without a namespace
# prefix (gpt2, gpt2-medium, xlnet-*) have no entry.
# NOTE(review): 'gpt2' therefore does not map to 'gpt2-base' as used in
# LMs_columns -- confirm how callers resolve it.
LMs_columns_names = {'bigscience/bloom-560m': 'bloom-560m',
                     'bigscience/bloom-3b': 'bloom-3b',
                     'facebook/bart-base': 'bart-base',
                     'facebook/bart-large': 'bart-large',
                     'meta-llama/Llama-2-7b-hf': 'llama-2-7b',
                     'meta-llama/Llama-2-13b-hf': 'llama-2-13b'}

# Short column names, one per model in LMs (same order).
LMs_columns = ['bloom-560m', 'bloom-3b',
               'gpt2-base', 'gpt2-medium',
               'xlnet-base-cased', 'xlnet-large-cased',
               'bart-base', 'bart-large',
               'llama-2-7b', 'llama-2-13b']

# Smaller model of each family.
# NOTE(review): uses 'gpt2' while LMs_columns uses 'gpt2-base' -- confirm
# which name the result columns actually carry.
LMs_base = ['bloom-560m', 'gpt2', 'xlnet-base-cased', 'bart-base', 'llama-2-7b']

# Larger model of each family.
LMs_large = ['bloom-3b', 'gpt2-medium', 'xlnet-large-cased', 'bart-large', 'llama-2-13b']

# Target categories, matched against the 'targetCategory' column.
unique_categories = ['culture', 'disabled', 'gender', 'race']

num_categories = len(unique_categories)

def get_unique(df, column):
    """Return the distinct items across all literal-encoded cells of ``column``.

    Each distinct cell value is parsed with ``ast.literal_eval`` and its items
    are merged into one set. The number of distinct items is printed.

    Returns:
        A list of the distinct items (order not meaningful).
    """
    parsed_cells = (ast.literal_eval(raw) for raw in df[column].unique())
    distinct = list(set().union(*parsed_cells))
    print(len(distinct))
    return distinct

def split_values(values):
    """Parse a list-like string such as '["a", "b"]' into a set of items.

    The string is split on commas. Each piece is stripped of surrounding
    whitespace, then of '[' / ']' characters, then of double quotes.
    Single quotes are left in place.
    """
    return {
        piece.strip().strip('[]').strip('"')
        for piece in values.split(',')
    }

def create_dict(df, unique_categories, column):
    """Collect the distinct lower-cased values of ``column`` for each category.

    A row contributes to ``category`` when ``category in row['targetCategory']``
    holds. For a string this is a substring test; for a list it is a
    membership test. Cells of ``column`` are split with ``split_values``.

    Args:
        df: DataFrame with a ``targetCategory`` column and ``column``.
        unique_categories: categories to collect values for.
        column: name of the column holding comma-separated values.

    Returns:
        dict mapping each category to a list of distinct lower-cased values,
        ordered by first occurrence across rows. Within a single cell the
        order follows set iteration. The per-category counts are printed.
    """
    # dict keys act as an insertion-ordered set: O(1) membership instead of
    # an O(n) `value not in list` scan for every value. One pass over the
    # rows serves all categories.
    buckets = {category: {} for category in unique_categories}
    for _, row in df.iterrows():
        matching = [category for category in buckets
                    if category in row['targetCategory']]
        if not matching:
            continue
        values = [value.lower() for value in split_values(row[column])]
        for category in matching:
            for value in values:
                buckets[category].setdefault(value, None)
    result = {category: list(values) for category, values in buckets.items()}
    for key, value in result.items():
        print(f"{key}: {len(value)}")
    return result

def split_rows(df, column):
    """Explode a comma-separated column into one row per value.

    Each row of ``df`` is repeated once per value returned by
    ``split_values(row[column])``. In each copy, ``column`` holds a
    one-element list ``[value]``. Rows are emitted in input order; values
    within one row follow set iteration order.

    Returns:
        A new DataFrame with ``df``'s columns and a fresh RangeIndex.
    """
    # Collect one-row frames and concatenate once. Concatenating inside the
    # loop recopies the accumulated frame on every iteration (quadratic).
    # Seeding the concat with an empty frame can trigger a FutureWarning in
    # recent pandas.
    pieces = []
    for _, row in df.iterrows():
        for value in split_values(row[column]):
            new_row = row.copy()
            new_row[column] = [value]
            pieces.append(pd.DataFrame(new_row).T)
    new_df = pd.concat(pieces) if pieces else pd.DataFrame(columns=df.columns)
    return new_df.reset_index(drop=True)

def filter(all_targets, df, dis=False):
    """Deduplicate a (TERM, POS) lexicon and turn each term into a group phrase.

    The name shadows the built-in ``filter``; it is kept for existing callers.

    When a term occurs with several POS labels, the "n" row is kept. Otherwise
    the row whose label first occurs earliest in the POS column wins. Each kept
    term is then rewritten:

    * ``dis=True`` and POS "pp": "people " + term, POS set to "n";
    * POS "adj": term + " people", POS set to "n";
    * any other POS: pluralised with inflect, POS unchanged.

    Args:
        all_targets: list extended in place with each kept term as written
            in ``df`` (before rewriting).
        df: DataFrame with "TERM" and "POS" columns; other columns are ignored.
        dis: enable the "pp" rule above.

    Returns:
        ``(terms, all_targets)``: the rewritten terms, ordered by POS rank,
        and the same ``all_targets`` list object.
    """
    df = df[["TERM", "POS"]]
    # Rank POS labels: "n" -> 0, every other label -> 1, 2, ... in order of
    # first appearance, so drop_duplicates keeps the preferred reading.
    pos_rank = {"n": 0}
    for pos in df["POS"].unique():
        if pos != "n":
            pos_rank[pos] = len(pos_rank)
    before = len(df)
    df = (
        df.assign(_pos_rank=df["POS"].map(pos_rank))
        .sort_values("_pos_rank", ascending=True)
        .drop_duplicates(subset=["TERM"])
        .drop(columns="_pos_rank")
    )
    print(before - len(df))  # number of duplicate terms dropped

    terms, new_pos = [], []
    for term, pos in zip(df["TERM"], df["POS"]):
        all_targets.append(term)
        if dis and pos == "pp":
            terms.append("people " + term)
            new_pos.append("n")
        elif pos != "adj":
            # Plural form: stereotypes are generalisations over a group.
            terms.append(infl.plural(term))
            new_pos.append(pos)
        else:
            # Adjectives become a noun phrase by appending "people".
            terms.append(term + " people")
            new_pos.append("n")
    df = df.assign(TERM=terms, POS=new_pos)
    print(len(df))
    print(df.head())
    return terms, all_targets

def filter_exceptions(all_targets, df, race=False, country=False):
    """Build target phrases from lexicons that need special handling.

    The flags select one of three modes:

    * ``race=True`` (takes precedence): every ``TERM`` is suffixed with " people".
    * ``country=True``: ``COUNTRY_ADJ`` values are deduplicated and pluralised.
    * otherwise: ``REGION_ADJ`` values are deduplicated and pluralised.

    Args:
        all_targets: list extended in place with every source term, before
            any suffixing or pluralisation.
        df: DataFrame containing the column used by the selected mode.
        race: use the race mode.
        country: in the non-race mode, read ``COUNTRY_ADJ`` instead of
            ``REGION_ADJ``.

    Returns:
        ``(terms, all_targets)``: the rewritten terms in row order and the
        same ``all_targets`` list object.
    """
    if race:
        source_terms = df["TERM"].tolist()
        all_targets.extend(source_terms)
        # Build a fresh frame rather than writing with .at into a column
        # slice, which raised SettingWithCopyWarning for every row.
        df = df[["TERM"]].assign(TERM=[term + " people" for term in source_terms])
    else:
        column = "COUNTRY_ADJ" if country else "REGION_ADJ"
        before = len(df)
        df = (
            df.drop_duplicates(subset=[column])[[column]]
            .rename(columns={column: "TERM"})
        )
        adjectives = df["TERM"].tolist()
        all_targets.extend(adjectives)
        df = df.assign(TERM=[infl.plural(adjective) for adjective in adjectives])
        print(before - len(df))  # number of duplicate rows dropped
    print(len(df))
    print(df.head())
    return df["TERM"].tolist(), all_targets

def find_subject(sentence):
    """Return a subject-like phrase for ``sentence``, or None.

    The sentence is parsed with the module-level spaCy pipeline ``nlp``.
    For the first token whose dependency label is in the list below, a
    phrase is assembled from:

    * the token's own text;
    * ADJ / NOUN descendants from its subtree, each inserted at the front;
    * ADP / CCONJ descendants from its subtree, each appended at the end;
    * its left children with dependency poss / det / compound, inserted at
      the front.

    Returns:
        The phrase as a space-joined string. Returns None (implicitly) when
        no token carries one of the listed labels.

    NOTE(review): the ``if subjects != []`` check sits inside the token loop.
    The function therefore returns as soon as the first matching token is
    handled; ``subjects`` never holds more than one phrase, and the
    ``else None`` branch is unreachable. Confirm whether joining all subjects
    was intended.
    NOTE(review): each ADJ/NOUN descendant is inserted at index 0, so several
    such descendants end up in reverse sentence order.
    NOTE(review): a NOUN left child with dep 'compound' is also part of the
    subtree, so it may be inserted twice.
    NOTE(review): 'PRP$' looks like a fine-grained ``tag_`` value, while
    ``pos_`` holds coarse tags, so that entry may never match -- confirm.
    """
    doc = nlp(sentence)
    subjects = []
    for tok in doc:
        if tok.dep_ in ['subj', 'nsubj', 'nsubjpass', 'csubj', 'csubjpass' , 'conj']:
            subject = [tok.text]
            # subtree includes tok itself, hence the `descendant != tok` guard.
            for descendant in tok.subtree:
                if descendant.pos_ in ('ADJ', 'NOUN', 'ADP', 'CCONJ', 'PRP$') and descendant != tok:
                    if descendant.pos_ == 'ADJ' or descendant.pos_ == 'NOUN':
                        subject.insert(0, descendant.text)
                    else: 
                        subject.append(descendant.text)
            # Walk left children right-to-left so repeated insert(0) keeps
            # their sentence order.
            for left in reversed(list(tok.lefts)):
                if left.dep_ == 'poss' or left.dep_ == 'det' or left.dep_ == 'compound':
                    subject.insert(0, left.text)
            subjects.append(' '.join(subject))
        # Inside the loop: returns after the first matching token (see NOTE).
        if subjects != []:
            return ' '.join(subjects) if subjects else None

def processing(text, all_targets):
    """Normalise a generated statement and decide whether to keep it.

    Punctuation is removed and runs of spaces are collapsed. The result is
    parsed with the module-level spaCy pipeline ``nlp``. The statement must
    start with a VERB/AUX token, optionally after one leading adverb. A
    singular present-tense verb (tag VBZ) is converted to its plural form
    with inflect.

    Args:
        text: the raw statement.
        all_targets: target terms; a statement containing any of them (as a
            substring) is rejected to avoid repetitive phrases.

    Returns:
        The cleaned statement, or None in any of these cases:
        * it is empty after cleaning;
        * it does not start with a verb;
        * it contains a target term or one of the excluded substrings below.
    """
    punctuation_marks = ['.', ',', '!', '?', ':', ';', '"', '“', '”', '-', '–', '(', ')', '[', ']', '{', '}', '...', '\\', '\\\\']
    punctuation_string = ''.join(punctuation_marks)
    modified_text = re.sub(r'[' + re.escape(punctuation_string) + r']', '', text)  # strip punctuation
    # Collapse every run of spaces left behind by removed punctuation. A single
    # str.replace('  ', ' ') only shortens runs of three or more spaces.
    modified_text = re.sub(r' {2,}', ' ', modified_text).strip()
    doc = nlp(modified_text)
    if len(doc) == 0:
        return None

    token = doc[0]
    # Skip one leading adverb. The parentheses matter: without them an
    # 'advmod' first token in a one-token doc made doc[1] raise IndexError.
    if (token.dep_ == 'advmod' or token.pos_ == 'ADV') and len(doc) > 1:
        token = doc[1]
    if token.pos_ not in ('VERB', 'AUX'):
        return None
    if token.tag_ == 'VBZ':
        modified_verb = infl.plural_verb(token.text)  # singular -> plural verb
        # Replace only this token, by its character offset. str.replace would
        # also rewrite matching substrings elsewhere (e.g. "is" inside "this").
        start = token.idx
        modified_text = (modified_text[:start] + modified_verb
                         + modified_text[start + len(token.text):])

    # Substring markers of meta-commentary or slurs. Misspellings such as
    # 'derogartory'/'assocaite' appear deliberate.
    # NOTE(review): 'chigago' may be meant as 'chicago'. 'n-word' can never
    # match because '-' is stripped above -- confirm both.
    excluded_substrings = ('slur', 'derogatory', 'derogartory', 'offensive', 'term', 'alluding', 'called', 'being', 'using', 'used', 'marginalized', 'marginalizes', 'marginalize', 'trivializes', 'trivialize', 'insinuates', 'insinuate', 'implies', 'imply', 'expresses', 'express', 'dislikes', 'dislike', 'put down', 'insult', 'seen as', 'assocaite', 'associates', 'associate', 'n-word', 'n word', 'nigga', 'n*gga', 'niggas', 'n*ggas', 'niggers', 'n*ggers', 'faggot', 'f*ggot', 'faggots', 'f*ggots', 'physically', 'syria', 'chigago', 'holocaust', 'group', 'groups', 'minority', 'minorities')
    has_target = any(target in modified_text for target in all_targets)
    has_terms_to_filter = any(marker in modified_text for marker in excluded_substrings)
    if has_target or has_terms_to_filter:
        return None
    return modified_text